// Copyright 2021-present StarRocks, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.starrocks.sql.optimizer;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.starrocks.sql.optimizer.base.ColumnRefSet;
import com.starrocks.sql.optimizer.operator.ColumnOutputInfo;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
import com.starrocks.sql.optimizer.rewrite.ReplaceColumnRefRewriter;

import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.StringJoiner;
import java.util.stream.Collectors;

/**
 * RowOutputInfo describes the output columns returned by an operator. It consists of a set of
 * {@link ColumnOutputInfo}. Because of the Projection field in Operator, an operator with a non-null
 * projection may take its own original output and project it to a new output.
 * <p>
 * To unify the output info of operators, RowOutputInfo describes the final output row of an operator:
 * <ul>
 *   <li>When an operator has a non-null projection, RowOutputInfo records the projection info (the
 *   projected columns and the common expressions they reference) as well as the output info of the
 *   operator itself.</li>
 *   <li>When an operator has no projection, RowOutputInfo records only the set of
 *   {@link ColumnOutputInfo} produced by the operator itself.</li>
 * </ul>
 * The "final" output of a row is the projected output when one exists, otherwise the operator's own
 * output (see {@link #getColumnOutputInfo()}).
 */
public class RowOutputInfo {

    // store the final output of the optExpression after projection (if exists); empty otherwise
    private final Map<Integer, ColumnOutputInfo> colOutputInfo;

    // store the final common exprs referenced by the output
    private final Map<Integer, ColumnOutputInfo> commonColInfo;

    // store the output info of an operator itself
    private final Map<Integer, ColumnOutputInfo> originalColOutputInfo;

    // store the cols generated by operator itself without using input cols.
    private final Set<ColumnRefOperator> endogenousCols;

    /**
     * Creates a RowOutputInfo with no columns, no common expressions and no endogenous columns.
     */
    public static RowOutputInfo createEmptyInfo() {
        return new RowOutputInfo();
    }

    private RowOutputInfo() {
        this.colOutputInfo = Maps.newHashMap();
        this.commonColInfo = Maps.newHashMap();
        this.originalColOutputInfo = Maps.newHashMap();
        this.endogenousCols = Sets.newHashSet();
    }

    /**
     * Creates a RowOutputInfo without projection whose own output is {@code columnEntries}.
     *
     * @param columnEntries the operator's output columns, keyed internally by column id
     */
    public RowOutputInfo(Collection<ColumnOutputInfo> columnEntries) {
        this(columnEntries, Lists.newArrayList());
    }

    /**
     * Creates a RowOutputInfo without projection.
     *
     * @param columnEntries  the operator's output columns, keyed internally by column id
     * @param endogenousCols columns generated by the operator itself without using input columns
     */
    public RowOutputInfo(Collection<ColumnOutputInfo> columnEntries, Collection<ColumnRefOperator> endogenousCols) {
        this.originalColOutputInfo = Maps.newHashMap();
        for (ColumnOutputInfo col : columnEntries) {
            originalColOutputInfo.put(col.getColId(), col);
        }
        this.colOutputInfo = Maps.newHashMap();
        this.commonColInfo = Maps.newHashMap();
        this.endogenousCols = Sets.newHashSet(endogenousCols);
    }

    /**
     * Creates a RowOutputInfo without projection from a column-ref to scalar-operator mapping.
     *
     * @param columnRefMap the operator's output columns and the expressions producing them
     */
    public RowOutputInfo(Map<ColumnRefOperator, ScalarOperator> columnRefMap) {
        this.originalColOutputInfo = Maps.newHashMap();
        for (Map.Entry<ColumnRefOperator, ScalarOperator> entry : columnRefMap.entrySet()) {
            originalColOutputInfo.put(entry.getKey().getId(), new ColumnOutputInfo(entry));
        }
        this.colOutputInfo = Maps.newHashMap();
        this.commonColInfo = Maps.newHashMap();
        this.endogenousCols = Sets.newHashSet();
    }

    /**
     * Creates a projection-only RowOutputInfo: the original output info is empty.
     *
     * @param columnRefMap the projected output columns
     * @param commonColMap the common expressions referenced by the projected columns
     */
    public RowOutputInfo(Map<ColumnRefOperator, ScalarOperator> columnRefMap,
                         Map<ColumnRefOperator, ScalarOperator> commonColMap) {
        this(columnRefMap, commonColMap, Maps.newHashMap(), Lists.newArrayList());
    }

    /**
     * Creates a RowOutputInfo with a projection on top of the operator's own output.
     *
     * @param columnRefMap          the projected output columns
     * @param commonColMap          the common expressions referenced by the projected columns
     * @param originalColOutputInfo the operator's own output info, keyed by column id; copied, so later
     *                              changes to the caller's map are not visible here
     * @param endogenousCols        columns generated by the operator itself without using input columns
     */
    public RowOutputInfo(Map<ColumnRefOperator, ScalarOperator> columnRefMap,
                         Map<ColumnRefOperator, ScalarOperator> commonColMap,
                         Map<Integer, ColumnOutputInfo> originalColOutputInfo,
                         Collection<ColumnRefOperator> endogenousCols) {
        this.colOutputInfo = Maps.newHashMap();
        for (Map.Entry<ColumnRefOperator, ScalarOperator> entry : columnRefMap.entrySet()) {
            colOutputInfo.put(entry.getKey().getId(), new ColumnOutputInfo(entry));
        }
        this.commonColInfo = Maps.newHashMap();
        for (Map.Entry<ColumnRefOperator, ScalarOperator> entry : commonColMap.entrySet()) {
            commonColInfo.put(entry.getKey().getId(), new ColumnOutputInfo(entry));
        }
        // Copy, like every other collection owned by this class, to avoid aliasing caller state.
        this.originalColOutputInfo = Maps.newHashMap(originalColOutputInfo);
        this.endogenousCols = Sets.newHashSet(endogenousCols);
    }

    /** Returns the live set of columns generated by the operator itself without using input columns. */
    public Set<ColumnRefOperator> getEndogenousCols() {
        return endogenousCols;
    }

    /** Returns the live projected-output map keyed by column id; empty when there is no projection. */
    public Map<Integer, ColumnOutputInfo> getColOutputInfo() {
        return colOutputInfo;
    }

    /** Returns the live map of the operator's own output info keyed by column id. */
    public Map<Integer, ColumnOutputInfo> getOriginalColOutputInfo() {
        return originalColOutputInfo;
    }

    /**
     * Returns a new list with the final output columns: the projected output when it is non-empty,
     * otherwise the operator's own output.
     */
    public List<ColumnOutputInfo> getColumnOutputInfo() {
        return Lists.newArrayList(chooseOutputMap().values());
    }

    /** Returns a new list with the common expressions referenced by the projected output. */
    public List<ColumnOutputInfo> getCommonColInfo() {
        return Lists.newArrayList(commonColInfo.values());
    }

    /** Returns the column refs of the final output columns. */
    public List<ColumnRefOperator> getOutputColRefs() {
        return chooseOutputMap().values().stream()
                .map(ColumnOutputInfo::getColumnRef)
                .collect(Collectors.toList());
    }

    /** Returns a mapping from each final output column ref to the scalar operator producing it. */
    public Map<ColumnRefOperator, ScalarOperator> getColumnRefMap() {
        return chooseOutputMap().values().stream()
                .collect(Collectors.toMap(ColumnOutputInfo::getColumnRef, ColumnOutputInfo::getScalarOp));
    }

    /** Returns the ids of the final output columns as a ColumnRefSet. */
    public ColumnRefSet getOutputColumnRefSet() {
        return ColumnRefSet.createByIds(chooseOutputMap().keySet());
    }

    /**
     * Returns the columns used from the children of an optExpression.
     * <p>
     * When there is no original output info (the projection-only case), this is the union of the
     * columns used by the projected columns and by the common expressions, minus the columns defined
     * by the common expressions themselves. Otherwise it is the union of the columns used by the
     * operator's own output. In both cases the endogenous columns are removed at the end.
     */
    public ColumnRefSet getUsedColumnRefSet() {
        ColumnRefSet columnRefSet = new ColumnRefSet();
        if (originalColOutputInfo.isEmpty()) {
            // this can only happen when the operator of an optExpression is project
            for (ColumnOutputInfo col : colOutputInfo.values()) {
                columnRefSet.union(col.getUsedColumns());
            }
            for (ColumnOutputInfo col : commonColInfo.values()) {
                columnRefSet.union(col.getUsedColumns());
            }
            // Must run after all unions: a common expr may be used by another common expr.
            for (ColumnOutputInfo col : commonColInfo.values()) {
                columnRefSet.except(col.getColumnRef().getUsedColumns());
            }
        } else {
            for (ColumnOutputInfo col : originalColOutputInfo.values()) {
                columnRefSet.union(col.getUsedColumns());
            }
        }

        for (ColumnRefOperator col : endogenousCols) {
            columnRefSet.except(col.getUsedColumns());
        }
        return columnRefSet;
    }

    /**
     * Rewrites the scalar operator of {@code columnOutputInfo}, replacing column refs according to
     * this row's final output mapping ({@link #getColumnRefMap()}).
     *
     * @param columnOutputInfo the column to rewrite; it is not modified
     * @return a new ColumnOutputInfo with the same column ref and the rewritten scalar operator
     */
    public ColumnOutputInfo rewriteColWithRowInfo(ColumnOutputInfo columnOutputInfo) {
        return rewriteCol(getColumnRefMap(), columnOutputInfo);
    }

    /**
     * Builds a new RowOutputInfo whose output is this row's final output plus {@code entryList}.
     *
     * @param entryList       extra columns to append
     * @param existProjection if true, this row's output columns are kept as-is and each entry is
     *                        rewritten against this row's output mapping; if false, this row's output
     *                        columns are passed through as identity mappings and entries are appended
     *                        unchanged
     * @return a new RowOutputInfo (without projection) holding the combined columns; endogenous
     *         columns and common expressions of this row are not carried over
     */
    public RowOutputInfo addColsToRow(List<ColumnOutputInfo> entryList, boolean existProjection) {
        List<ColumnOutputInfo> newCols = Lists.newArrayList();
        if (existProjection) {
            newCols.addAll(getColumnOutputInfo());
            // The mapping is loop-invariant; build it once rather than once per entry.
            Map<ColumnRefOperator, ScalarOperator> columnRefMap = getColumnRefMap();
            for (ColumnOutputInfo entry : entryList) {
                newCols.add(rewriteCol(columnRefMap, entry));
            }
        } else {
            for (ColumnOutputInfo entry : getColumnOutputInfo()) {
                newCols.add(new ColumnOutputInfo(entry.getColumnRef(), entry.getColumnRef()));
            }
            newCols.addAll(entryList);
        }
        return new RowOutputInfo(newCols);
    }

    /** Equality and hashing consider only the ids of the final output columns. */
    @Override
    public int hashCode() {
        return getOutputColumnRefSet().hashCode();
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof RowOutputInfo)) {
            return false;
        }

        RowOutputInfo that = (RowOutputInfo) obj;

        return Objects.equals(getOutputColumnRefSet(), that.getOutputColumnRefSet());
    }

    /**
     * Renders the final output columns. It uses the same selection as equals/hashCode, so a row
     * without a projection no longer prints as empty.
     */
    @Override
    public String toString() {
        StringJoiner joiner = new StringJoiner(", ", "[", "]");
        for (ColumnOutputInfo entry : chooseOutputMap().values()) {
            joiner.add(entry.toString());
        }
        return joiner.toString();
    }

    // Rewrites columnOutputInfo's scalar operator using the given column-ref replacement mapping.
    private static ColumnOutputInfo rewriteCol(Map<ColumnRefOperator, ScalarOperator> columnRefMap,
                                               ColumnOutputInfo columnOutputInfo) {
        ReplaceColumnRefRewriter rewriter = new ReplaceColumnRefRewriter(columnRefMap);
        return new ColumnOutputInfo(columnOutputInfo.getColumnRef(), rewriter.rewrite(columnOutputInfo.getScalarOp()));
    }

    // The projected output when present, otherwise the operator's own output.
    private Map<Integer, ColumnOutputInfo> chooseOutputMap() {
        return colOutputInfo.isEmpty() ? originalColOutputInfo : colOutputInfo;
    }
}
